import argparse
import torch

class Options:
    """Command-line options for the [TII] experiments.

    Construct once with ``Options()`` and call :meth:`parse` to obtain an
    ``argparse.Namespace``. Besides plain parsing, ``parse`` resolves the
    compute device (``use_gpu`` / ``device``) and, for the datasets listed in
    ``_TASK_BY_DATA_PATH``, forces ``exp_task`` to match the dataset.
    """

    # Datasets whose task type is fixed: parse() overrides --exp_task for these
    # data_path values; any other data_path keeps the user's --exp_task.
    _TASK_BY_DATA_PATH = {
        'Monash/Monash_UEA_UCR_Regression_Archive': 'regression',
        'Isdb': 'classification',
        'UEA/Multivariate2018_ts': 'classification',
        'BearingsDetection': 'classification',
    }

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as 'true'/'False'/'1'/'no' to a bool.

        ``type=bool`` cannot be used for this: argparse would call
        ``bool('False')``, which is ``True`` for every non-empty string.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    def __init__(self):
        self.parser = argparse.ArgumentParser(description='[TII] Experiment')
        parser = self.parser

        # ---------------------------------------------------------------------
        # Common config for all data
        # ---------------------------------------------------------------------
        parser.add_argument('--root_path', type=str, default='../data', help='root path of the data file')
        parser.add_argument('--data_path', type=str, default='BearingsDetection',
                            help='Monash/Monash_UEA_UCR_Regression_Archive -> [AppliancesEnergy][BeijingPM25Quality]'
                                 '[BeijingPM10Quality][BenzeneConcentration][IEEEPPG][LiveFuelMoistureContent]; '
                                 'Windpower/raw_windpower/dataset -> [JSDF001]; '
                                 'Isdb -> isdb; '
                                 'UEA/Multivariate2018_ts -> Heartbeat; '
                                 'BearingsDetection -> bearingdetection')
        parser.add_argument('--data_name', type=str, default='bearingdetection', help='data_name')
        parser.add_argument('--num_workers', type=int, default=0, help='dataloader num workers')

        # ---------------------------------------------------------------------
        # Config for ISDB classification
        # ---------------------------------------------------------------------
        parser.add_argument('--test_mode', type=str, default='test_26',
                            help='Whether to use 26 loops to test or randomly select N loops to test')
        parser.add_argument('--test_n', type=int, default=15,
                            help='how many loops are used to test in each class if test_mode is test_random')

        # ---------------------------------------------------------------------
        # Config for Monash regression
        # ---------------------------------------------------------------------
        parser.add_argument('--subsample_factor', type=int, default=None,
                            help='Sub-sampling factor used for long sequences: keep every kth sample')
        parser.add_argument('--dataencodingtype', type=str, default='utf-8',
                            help='text encoding of the data files; use ISO-8859-1 for the specific datasets '
                                 'BIDMC32HR, BIDMC32RR and BIDMC32SpO2')

        # ---------------------------------------------------------------------
        # Config for power-data forecasting
        # ---------------------------------------------------------------------
        parser.add_argument('--split_step', type=int, default=8, help='Index Difference between each training')
        parser.add_argument('--input_length', type=int, default=960, help='input sequence length of forecasting model')
        # nargs='+' so a CLI value is a list, matching the list-typed default.
        parser.add_argument('--input_cols', nargs='+', default=['all'],
                            help='which features to use: default [all], or a list of column names, '
                                 'e.g. --input_cols feat1 feat2 ...')
        parser.add_argument('--output_length', type=int, default=672, help='output sequence length of forecasting model')

        # ---------------------------------------------------------------------
        # Config for experiment
        # ---------------------------------------------------------------------
        # NOTE(review): no type= is given, so the seed arrives as str (or None).
        parser.add_argument('--exp_seed', help='Seed used for experiments')
        parser.add_argument('--exp_encoder', type=str, default='cnn', help='which type encoder is used')
        # May be overridden in parse() for datasets listed in _TASK_BY_DATA_PATH.
        parser.add_argument('--exp_task', choices=('classification', 'regression'), default='regression')

        parser.add_argument('--exp_savepath', type=str, default='./save', help='save path')
        parser.add_argument('--exp_segment_savepath', type=str, default='./save/segment_results',
                            help='save path for segmentation results')
        parser.add_argument('--exp_batchsize', type=int, default=64, help='batch size of train input data')
        parser.add_argument('--exp_epochs', type=int, default=75, help='Number of training epochs')
        parser.add_argument('--exp_lr', type=float, default=1e-3, help='learning rate (default holds for batch size 64)')
        parser.add_argument('--global_reg', action='store_true',
                            help='If set, L2 regularization will be applied to all weights instead of only the output layer')
        parser.add_argument('--l2_reg', type=float, default=0, help='L2 weight regularization parameter')
        parser.add_argument('--model_savepath', type=str, default='./save/model', help='path to save model')
        parser.add_argument('--testonly', action='store_true',
                            help='If set, no training will take place; instead, trained model will be loaded and evaluated on test set')

        # ---------------------------------------------------------------------
        # Config for time-series segmentation
        # ---------------------------------------------------------------------
        parser.add_argument('--seg_distance', type=str, default='cosine', help='which distance is used')

        # ---------------------------------------------------------------------
        # Config for transformer model
        # ---------------------------------------------------------------------
        parser.add_argument('--embedding_dim', type=int, default=128, help='dimension of the embedding')
        parser.add_argument('--ts_maxlen', type=int,
                            help='Max length of series, default is the true length of the sequence')
        parser.add_argument('--pos_encoding', choices=('fixed', 'learnable'), default='fixed',
                            help='Type of positional encoding')
        parser.add_argument('--dropout', type=float, default=0.1,
                            help='Dropout applied to most transformer encoder layers')
        parser.add_argument('--num_layers', type=int, default=3, help='Number of transformer encoder layers (blocks)')
        parser.add_argument('--num_heads', type=int, default=8, help='Number of transformer heads')
        parser.add_argument('--normalization_layer', choices=('BatchNorm', 'LayerNorm'), default='BatchNorm',
                            help='Normalization layer to be used internally in transformer encoder')
        parser.add_argument('--dim_feedforward', type=int, default=256,
                            help='Dimension of dense feedforward part of transformer layer')
        parser.add_argument('--transformer_activation', choices=('relu', 'gelu'), default='gelu',
                            help='Activation to be used in transformer encoder')

        # ---------------------------------------------------------------------
        # Config for CNN model
        # ---------------------------------------------------------------------
        parser.add_argument('--channels', type=int, default=16, help='Number of channels manipulated in the causal CNN')
        parser.add_argument('--depth', type=int, default=5, help='Depth of the causal CNN')
        parser.add_argument('--reduced_size', type=int, default=160,
                            help='Fixed length to which the output time series of the causal CNN is reduced')
        parser.add_argument('--out_channels', type=int, default=64, help='Number of output channels')
        parser.add_argument('--kernel_size', type=int, default=3, help='Kernel size of the applied non-residual convolutions')

        # ---------------------------------------------------------------------
        # Config for GPU
        # ---------------------------------------------------------------------
        # _str2bool instead of type=bool so that "--use_gpu False" really disables the GPU.
        parser.add_argument('--use_gpu', type=self._str2bool, default=True, help='use gpu (true/false)')
        parser.add_argument('--device', type=str, default='cuda:0', help='cuda:0 or cpu')

    def parse(self, argv=None):
        """Parse command-line options and resolve the derived settings.

        Args:
            argv: list of argument strings to parse; ``None`` (the default)
                makes argparse read ``sys.argv``.

        Returns:
            argparse.Namespace in which ``use_gpu`` is a bool that is true only
            when a GPU was requested, ``--device`` names a CUDA device, and CUDA
            is available; ``device`` is the resulting ``torch.device``
            (``cpu`` otherwise); and ``exp_task`` is forced for the datasets in
            ``_TASK_BY_DATA_PATH``.

        Exits via ``parser.error`` (SystemExit) if ``--device`` is not a valid
        torch device string.
        """
        args = self.parser.parse_args(argv)

        # Check GPU config: honour --device instead of always using cuda:0.
        try:
            requested = torch.device(args.device)
        except RuntimeError as err:
            self.parser.error(f'invalid --device {args.device!r}: {err}')
        args.use_gpu = bool(args.use_gpu and requested.type == 'cuda' and torch.cuda.is_available())
        args.device = requested if args.use_gpu else torch.device('cpu')

        args.exp_task = self._TASK_BY_DATA_PATH.get(args.data_path, args.exp_task)
        return args

